#################################################################################
# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
# All Rights Reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################

"""
Reference:

Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation,
Liang-Chieh Chen, Yukun Zhu, George Papandreou, Florian Schroff, and Hartwig Adam,
Google Inc., https://arxiv.org/pdf/1802.02611.pdf
"""

import numpy as np
from collections import OrderedDict
import torch
from .... import xnn


from .pixel2pixelnet import *

try: from .pixel2pixelnet_internal import *
except: pass

from ..multi_input_net import MobileNetV2TVMI4, ResNet50MI4, MobileNetV2TVNV12MI4, MobileNetV2TVGWSMI4
from .deeplabv3lite import DeepLabV3LiteDecoder, DeepLabV3Lite
from .deeplabv3lite import get_config_deeplav3lite_mnv2, deeplabv3lite_mobilenetv2_tv


# Public API of this module.
__all__ = [
    'get_config_deeplav3lite_mnv2_gws',
    'deeplabv3lite_mobilenetv2_tv_gws',
    'deeplabv3lite_mobilenetv2_tv_es32',
    'deeplabv3lite_mobilenetv2_tv_mi4_es32',
    'deeplabv3lite_mobilenetv2_tv_nv12',
    'student_teacher_learner_nv12',
]


######################################
class DeepLabV3LiteMobileNetV2TVNV12(DeepLabV3Lite):
    """DeepLabV3Lite with a MobileNetV2 encoder variant that takes NV12 input."""
    def __init__(self, model_config):
        # Fill any options the caller did not set with the standard mnv2 defaults.
        merged_config = get_config_deeplav3lite_mnv2().merge_from(model_config)
        # Build the NV12-aware multi-input encoder from a private copy of the
        # config so encoder-side mutations do not leak into the decoder config.
        encoder = MobileNetV2TVNV12MI4(model_config=merged_config.clone())
        super().__init__(encoder, merged_config)


def deeplabv3lite_mobilenetv2_tv_nv12(model_config, pretrained=False):
    """Build the NV12-input DeepLabV3Lite model.

    Returns the model and the name-remapping dict used to load pretrained
    weights whose layer names follow the single-stream/single-decoder layout.
    """
    model = DeepLabV3LiteMobileNetV2TVNV12(model_config)

    num_inputs = len(model_config.input_channels)
    if model_config.num_decoders is None:
        num_decoders = len(model_config.output_channels)
    else:
        num_decoders = model_config.num_decoders

    # Pretrained checkpoints name a single decoder ('decoders.0.'); fan the
    # weights out to every decoder. With multiple inputs, encoder features
    # are likewise replicated across the per-input streams.
    decoder_targets = ['decoders.{}.'.format(d) for d in range(num_decoders)]
    if num_inputs > 1:
        stream_targets = ['encoder.features.stream{}.'.format(s) for s in range(num_inputs)]
        change_names_dict = {'^features.': stream_targets,
                             '^encoder.features.': stream_targets,
                             '^decoders.0.': decoder_targets}
    else:
        change_names_dict = {'^features.': 'encoder.features.',
                             '^decoders.0.': decoder_targets}
    #

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict,
                                       ignore_size=True, verbose=True)

    return model, change_names_dict


class StudentTeacherDeepLabV3LiteMobileNetV2TVNV12(DeepLabV3LiteMobileNetV2TVNV12):
    """Student-teacher pair for feature-matching distillation.

    The student is the NV12-input encoder inherited from the parent class;
    the teacher is a regular deeplabv3lite_mobilenetv2_tv encoder. Both are
    truncated to their first feature stage and the decoders are discarded;
    forward() returns the (signed) difference between student and teacher
    features so the training loss can drive the student toward the teacher.
    """
    def __init__(self, **kwargs):
        # kwargs is expected to carry model_config (the parent's only argument).
        super().__init__(**kwargs)
        # Teacher network; the helper also returns a name-remapping dict we drop.
        self.teacher, _ = deeplabv3lite_mobilenetv2_tv(**kwargs)
        # Keep only the first feature stage of both encoders - distillation
        # here compares early features, not the full backbone output.
        self.teacher.encoder.features = self.teacher.encoder.features[:1]
        self.encoder.features = self.encoder.features[:1]

        # Decoders are unused for feature matching; drop them to save memory.
        self.decoders = None
        self.teacher.decoders = None
        # Signed subtraction block so negative differences are preserved.
        self.sub = xnn.layers.SubtractBlock(signed=True)

    def forward(self, x):
        # x arrives as a list/tuple of inputs; the student consumes the first.
        x = x[0]
        pred = self.encoder.features(x)
        # NOTE(review): x has already been replaced by x[0] above, so x[2]
        # indexes INTO the first input (e.g. a batch/plane slice) rather than
        # selecting a third network input - confirm this is intended and that
        # the teacher is meant to see that slice.
        target = self.teacher.encoder.features(x[2])
        #diff = target - pred
        diff = self.sub((pred, target))
        return diff


def student_teacher_learner_nv12(model_config, pretrained=None):
    """Build the NV12 student-teacher distillation model.

    Returns the model and the name-remapping dict that loads a pretrained
    single-network checkpoint into the teacher branch.
    """
    merged_config = get_config_deeplav3lite_mnv2().merge_from(model_config)
    # The model makes its own working copy of the (merged) encoder config.
    model = StudentTeacherDeepLabV3LiteMobileNetV2TVNV12(model_config=merged_config.clone())

    # Checkpoints name layers for a standalone network; redirect encoder and
    # decoder weights onto the teacher branch of the pair.
    change_names_dict = {
        '^encoder.features.': 'teacher.encoder.features.',
        '^features.': 'teacher.encoder.features.',
        '^decoders.': 'teacher.decoders.',
    }

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict,
                                       ignore_size=True, verbose=True)

    return model, change_names_dict


###########################################
# Groupwise Separable (GWS) convolutions
def get_config_deeplav3lite_mnv2_gws():
    """Return default config for the groupwise-separable (GWS) DeepLabV3Lite."""
    cfg = get_config_deeplav3lite_mnv2()
    cfg.groupwise_sep = True
    # Shallow and deep shortcut channel counts (deep one is 5x the shallow).
    cfg.shortcut_channels = (64, 64 * 5)
    cfg.shortcut_out = 56
    # Decoder/ASPP widths chosen to be divisible by the group count below.
    cfg.decoder_chan = 252
    cfg.aspp_chan = 252
    cfg.aspp_grps = 4
    cfg.fastdown = False
    return cfg


class DeepLabV3LiteGWSDecoder(DeepLabV3LiteDecoder):
    """Decoder for the groupwise-separable (GWS) DeepLabV3Lite variant.

    Currently identical to DeepLabV3LiteDecoder; kept as a distinct class so
    GWS-specific decoder changes can be made without touching the base class.
    The redundant ``__init__`` that only forwarded its argument to ``super()``
    has been removed - Python inherits the base constructor automatically.
    """


class DeepLabV3LiteGWS(Pixel2PixelNet):
    """DeepLabV3Lite wired with the groupwise-separable decoder."""
    def __init__(self, base_model, model_config):
        # Apply GWS defaults for any option the caller did not set.
        merged = get_config_deeplav3lite_mnv2_gws().merge_from(model_config)
        super().__init__(base_model, DeepLabV3LiteGWSDecoder, merged)


def deeplabv3lite_mobilenetv2_tv_gws(model_config, pretrained=None):
    """Build groupwise-separable DeepLabV3Lite over a MobileNetV2 encoder.

    Returns the model and the name-remapping dict for loading pretrained
    backbone weights.
    """
    model_config = get_config_deeplav3lite_mnv2_gws().merge_from(model_config)

    # Round the shortcut channel counts down so they divide evenly by the
    # depthwise-separable group size; the deeper shortcut is additionally
    # rounded to the lcm of that group size and 4 so encoder and decoder
    # groupings both fit.
    group = model_config.group_size_dws
    group_lcm = int(np.lcm(group, 4))

    def _floor_to(value, multiple):
        # Largest multiple of `multiple` that is <= value.
        return (value // multiple) * multiple

    shallow, deep = model_config.shortcut_channels
    model_config.shortcut_channels = (_floor_to(shallow, group),
                                      _floor_to(deep, group_lcm))

    # Encoder: built from a private copy of the config, with the overall
    # output stride derived from the per-stage strides.
    model_config_e = model_config.clone()
    model_config_e.output_stride = np.prod(model_config_e.strides)
    base_model = MobileNetV2TVGWSMI4(model_config_e)

    # Decoder: experimenting with hybrid depthwise-separable conv - Ni/G can
    # be more than 1 for the depthwise step, and group size is set to Ni/G
    # for the pointwise conv, with channel shuffle between the two layers.
    model = DeepLabV3LiteGWS(base_model, model_config)

    change_names_dict = {'^features.': 'encoder.features.'}
    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict,
                                       ignore_size=True, verbose=True)

    return model, change_names_dict


######################################
# DeepLabV3Lite, but with encoder stride of 32
def deeplabv3lite_mobilenetv2_tv_es32(model_config, pretrained=None):
    """DeepLabV3Lite over MobileNetV2 with encoder output stride 32.

    Returns the model and the name-remapping dict for pretrained weights.
    """
    # Encoder: stride 2 at every stage -> overall output stride of 32.
    enc_config = model_config.clone()
    enc_config.strides = (2, 2, 2, 2, 2)
    encoder_stride = np.prod(enc_config.strides)
    enc_config.shortcut_strides = (8, encoder_stride)
    base_model = MobileNetV2TVMI4(enc_config)

    # Decoder is built from the caller's (unmodified) config.
    model = DeepLabV3Lite(base_model, model_config)

    change_names_dict = {'^features.': 'encoder.features.'}
    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict,
                                       ignore_size=True, verbose=True)

    return model, change_names_dict


######################################
# DeepLabV3Lite (multi-input variant), but with encoder stride of 32
def deeplabv3lite_mobilenetv2_tv_mi4_es32(model_config, pretrained=None):
    """Multi-input DeepLabV3Lite over MobileNetV2 with encoder stride 32.

    Returns the model and the name-remapping dict that replicates
    single-stream/single-decoder pretrained weights across all input
    streams and decoders.
    """
    # Encoder: stride 2 at every stage -> overall output stride of 32.
    enc_config = model_config.clone()
    enc_config.strides = (2, 2, 2, 2, 2)
    encoder_stride = np.prod(enc_config.strides)
    enc_config.shortcut_strides = (8, encoder_stride)
    base_model = MobileNetV2TVMI4(enc_config)

    # Decoder is built from the caller's (unmodified) config.
    model = DeepLabV3Lite(base_model, model_config)

    num_inputs = len(model_config.input_channels)
    if model_config.num_decoders is None:
        num_decoders = len(model_config.output_channels)
    else:
        num_decoders = model_config.num_decoders

    # Fan single-stream encoder weights out to every input stream, and the
    # single decoder's weights out to every decoder.
    stream_targets = ['encoder.features.stream{}.'.format(s) for s in range(num_inputs)]
    change_names_dict = {'^features.': stream_targets,
                         '^encoder.features.': stream_targets,
                         '^decoders.0.': ['decoders.{}.'.format(d) for d in range(num_decoders)]}

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict,
                                       ignore_size=True, verbose=True)

    return model, change_names_dict

